{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Unigram tokenization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install datasets evaluate transformers[sentencepiece]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [\n",
    "    \"This is the Hugging Face Course.\",\n",
    "    \"This chapter is about tokenization.\",\n",
    "    \"This section shows several tokenizer algorithms.\",\n",
    "    \"Hopefully, you will be able to understand how they are trained and generate tokens.\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"xlnet-base-cased\")"
   ]
  },
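  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before counting word frequencies, it helps to see what the XLNet pre-tokenizer returns. The quick example below is an addition to this notebook: `pre_tokenize_str` splits a sentence into words prefixed with the `▁` character (marking a preceding space), together with their character offsets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added example (not part of the original notebook): inspect the pre-tokenizer output.\n",
    "# Words come back prefixed with \"▁\" (marking a preceding space), with their offsets.\n",
    "tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(\"Hello, how are you?\")"
   ]
  },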
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "word_freqs = defaultdict(int)\n",
    "for text in corpus:\n",
    "    words_with_offsets = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)\n",
    "    new_words = [word for word, offset in words_with_offsets]\n",
    "    for word in new_words:\n",
    "        word_freqs[word] += 1\n",
    "\n",
    "word_freqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('▁t', 7), ('is', 5), ('er', 5), ('▁a', 5), ('▁to', 4), ('to', 4), ('en', 4), ('▁T', 3), ('▁Th', 3), ('▁Thi', 3)]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "char_freqs = defaultdict(int)\n",
    "subwords_freqs = defaultdict(int)\n",
    "for word, freq in word_freqs.items():\n",
    "    for i in range(len(word)):\n",
    "        char_freqs[word[i]] += freq\n",
    "        # Loop through the subwords of length at least 2\n",
    "        for j in range(i + 2, len(word) + 1):\n",
    "            subwords_freqs[word[i:j]] += freq\n",
    "\n",
    "# Sort subwords by frequency\n",
    "sorted_subwords = sorted(subwords_freqs.items(), key=lambda x: x[1], reverse=True)\n",
    "sorted_subwords[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "token_freqs = list(char_freqs.items()) + sorted_subwords[: 300 - len(char_freqs)]\n",
    "token_freqs = {token: freq for token, freq in token_freqs}"
   ]
  },
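  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check (added to this notebook), the initial vocabulary should contain exactly 300 tokens: every single character plus the most frequent subwords."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added check (not in the original notebook): the initial vocabulary size\n",
    "print(len(token_freqs))"
   ]
  },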
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import log\n",
    "\n",
    "total_sum = sum([freq for token, freq in token_freqs.items()])\n",
    "model = {token: -log(freq / total_sum) for token, freq in token_freqs.items()}"
   ]
  },
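  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each score in the model is a negative log probability, so exponentiating the scores and summing them should give back (approximately) 1. The sanity check below is an addition to this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import exp\n",
    "\n",
    "# Added sanity check (not in the original notebook): the scores are negative log\n",
    "# probabilities, so their exponentials should sum to (approximately) 1.\n",
    "print(sum(exp(-score) for score in model.values()))"
   ]
  },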
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_word(word, model):\n",
    "    best_segmentations = [{\"start\": 0, \"score\": 1}] + [\n",
    "        {\"start\": None, \"score\": None} for _ in range(len(word))\n",
    "    ]\n",
    "    for start_idx in range(len(word)):\n",
    "        # This should be properly filled by the previous steps of the loop\n",
    "        best_score_at_start = best_segmentations[start_idx][\"score\"]\n",
    "        for end_idx in range(start_idx + 1, len(word) + 1):\n",
    "            token = word[start_idx:end_idx]\n",
    "            if token in model and best_score_at_start is not None:\n",
    "                score = model[token] + best_score_at_start\n",
    "                # If we have found a better segmentation ending at end_idx, we update\n",
    "                if (\n",
    "                    best_segmentations[end_idx][\"score\"] is None\n",
    "                    or best_segmentations[end_idx][\"score\"] > score\n",
    "                ):\n",
    "                    best_segmentations[end_idx] = {\"start\": start_idx, \"score\": score}\n",
    "\n",
    "    segmentation = best_segmentations[-1]\n",
    "    if segmentation[\"score\"] is None:\n",
    "        # We did not find a tokenization of the word -> unknown\n",
    "        return [\"<unk>\"], None\n",
    "\n",
    "    score = segmentation[\"score\"]\n",
    "    start = segmentation[\"start\"]\n",
    "    end = len(word)\n",
    "    tokens = []\n",
    "    while start != 0:\n",
    "        tokens.insert(0, word[start:end])\n",
    "        next_start = best_segmentations[start][\"start\"]\n",
    "        end = start\n",
    "        start = next_start\n",
    "    tokens.insert(0, word[start:end])\n",
    "    return tokens, score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(['H', 'o', 'p', 'e', 'f', 'u', 'll', 'y'], 41.5157494601402)\n",
       "(['This'], 6.288267030694535)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(encode_word(\"Hopefully\", model))\n",
    "print(encode_word(\"This\", model))"
   ]
  },
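  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If a word contains a character that never appears in the corpus, no segmentation is possible and `encode_word` falls back to the unknown token. The example below is an addition to this notebook and assumes the letter \"x\" does not occur in the toy corpus above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added example (assumption: the letter \"x\" never appears in the toy corpus), so no\n",
    "# token can cover it and encode_word returns the unknown token with no score.\n",
    "print(encode_word(\"xlnet\", model))"
   ]
  },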
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_loss(model):\n",
    "    loss = 0\n",
    "    for word, freq in word_freqs.items():\n",
    "        _, word_loss = encode_word(word, model)\n",
    "        loss += freq * word_loss\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "413.10377642940875"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compute_loss(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "\n",
    "def compute_scores(model):\n",
    "    scores = {}\n",
    "    model_loss = compute_loss(model)\n",
    "    for token, score in model.items():\n",
    "        # We always keep tokens of length 1\n",
    "        if len(token) == 1:\n",
    "            continue\n",
    "        model_without_token = copy.deepcopy(model)\n",
    "        _ = model_without_token.pop(token)\n",
    "        scores[token] = compute_loss(model_without_token) - model_loss\n",
    "    return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6.376412403623874\n",
       "0.0"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores = compute_scores(model)\n",
    "print(scores[\"ll\"])\n",
    "print(scores[\"his\"])"
   ]
  },
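  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The added check below illustrates those two scores: without \"ll\", \"Hopefully\" has to be segmented with two single \"l\" tokens, which raises the loss, while \"▁This\" is already best tokenized as the single token \"▁This\", so removing \"his\" changes nothing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added check (not in the original notebook): without \"ll\", \"Hopefully\" needs two\n",
    "# single \"l\" tokens, which is what makes the loss go up.\n",
    "model_without_ll = {token: score for token, score in model.items() if token != \"ll\"}\n",
    "print(encode_word(\"Hopefully\", model_without_ll))\n",
    "\n",
    "# \"▁This\" is already best tokenized as a single token, so it never relies on \"his\".\n",
    "print(encode_word(\"▁This\", model))"
   ]
  },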
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "percent_to_remove = 0.1\n",
    "while len(model) > 100:\n",
    "    scores = compute_scores(model)\n",
    "    sorted_scores = sorted(scores.items(), key=lambda x: x[1])\n",
    "    # Remove percent_to_remove tokens with the lowest scores.\n",
    "    for i in range(int(len(model) * percent_to_remove)):\n",
    "        _ = token_freqs.pop(sorted_scores[i][0])\n",
    "\n",
    "    total_sum = sum([freq for token, freq in token_freqs.items()])\n",
    "    model = {token: -log(freq / total_sum) for token, freq in token_freqs.items()}"
   ]
  },
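  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A last check, added to this notebook: the loop stops once the vocabulary holds at most 100 tokens, and since single-character tokens are never removed, the model can still segment every word of the corpus."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added check (not in the original notebook): final vocabulary size and a peek at\n",
    "# some of the surviving multi-character tokens.\n",
    "print(len(model))\n",
    "print(sorted(token for token in model if len(token) > 1)[:10])"
   ]
  },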
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['▁This', '▁is', '▁the', '▁Hugging', '▁Face', '▁', 'c', 'ou', 'r', 's', 'e', '.']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def tokenize(text, model):\n",
    "    words_with_offsets = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)\n",
    "    pre_tokenized_text = [word for word, offset in words_with_offsets]\n",
    "    encoded_words = [encode_word(word, model)[0] for word in pre_tokenized_text]\n",
    "    return sum(encoded_words, [])\n",
    "\n",
    "\n",
    "tokenize(\"This is the Hugging Face course.\", model)"
   ]
  }
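  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison (an added example, not part of the original notebook), the pretrained XLNet tokenizer also uses a Unigram model, but it was trained on a much larger corpus with a much larger vocabulary, so its segmentation of the same sentence will differ from our toy model's."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added comparison (not in the original notebook): the pretrained XLNet Unigram\n",
    "# tokenizer's segmentation of the same sentence.\n",
    "print(tokenizer.tokenize(\"This is the Hugging Face course.\"))"
   ]
  }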
 ],
 "metadata": {
  "colab": {
   "name": "Unigram tokenization",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
